{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jeeves/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import warnings\n",
    "from typing import List, Optional, Tuple, Union, Dict\n",
    "from collections import OrderedDict\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import re\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "import logging\n",
    "from configuration_minicpm import MiniCPMConfig  # 直接导入\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MiniCPM 采用标准的 Decoder 作为其架构，主要包括三个部分：Embedding, Attention 和 MLP 层。我们对每一部分进行拆解，以便更好地理解其工作原理。整体代码源自于 [MiniCPM 官方仓库](https://github.com/OpenBMB/MiniCPM)，这里逐步搭建模型，以便更好地理解其工作原理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = MiniCPMConfig(**json.load(open(\"config.json\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class BaseModelOutputWithPast(OrderedDict):\n",
    "    last_hidden_state: torch.FloatTensor = None\n",
    "    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n",
    "    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n",
    "    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None\n",
    "    \n",
    "@dataclass\n",
    "class CausalLMOutputWithPast(OrderedDict):\n",
    "    loss: Optional[torch.FloatTensor] = None\n",
    "    logits: torch.FloatTensor = None\n",
    "    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n",
    "    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n",
    "    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RoPE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在计算 Embedding 时，采用了 RoPE（Rotary Positional Embedding）的相对位置编码方式，帮助模型更好地理解序列中的位置信息。RoPE 的核心思想是将位置编码的计算转换为旋转矩阵的计算，从而减少计算量。RoPE 的计算公式如下："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MiniCPMRotaryEmbedding` 实现了旋转位置嵌入（Rotary Position Embedding）。它计算并缓存旋转位置编码的余弦和正弦值，以便在前向传播过程中快速获取。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiniCPMRotaryEmbedding(nn.Module):\n",
    "    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n",
    "        super().__init__()\n",
    "\n",
    "        self.dim = dim\n",
    "        self.max_position_embeddings = max_position_embeddings\n",
    "        self.base = base\n",
    "        # 计算了逆频率inv_freq并使用register_buffer方法将其注册为一个缓冲区\n",
    "        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n",
    "        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n",
    "\n",
    "        # 构建缓存\n",
    "        self._set_cos_sin_cache(\n",
    "            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32\n",
    "        )\n",
    "\n",
    "    def _set_cos_sin_cache(self, seq_len, device, dtype):\n",
    "        # 计算并缓存余弦和正弦值\n",
    "        self.max_seq_len_cached = seq_len\n",
    "        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n",
    "        freqs = torch.outer(t, self.inv_freq)\n",
    "\n",
    "        # 将频率扩展到维度上\n",
    "        emb = torch.cat((freqs, freqs), dim=-1)\n",
    "\n",
    "        # 缓存余弦值和正弦值\n",
    "        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n",
    "        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n",
    "\n",
    "    def forward(self, x, seq_len=None):\n",
    "        # 首先检查输入序列的长度是否超过了缓存的最大长度，如果超过了，则重新计算并缓存余弦和正弦值\n",
    "        # x: [bs, num_attention_heads, seq_len, head_size]\n",
    "        if seq_len > self.max_seq_len_cached:\n",
    "            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n",
    "\n",
    "        # 返回对应序列长度的余弦和正弦值\n",
    "        return (\n",
    "            self.cos_cached[:seq_len].to(dtype=x.dtype),\n",
    "            self.sin_cached[:seq_len].to(dtype=x.dtype),\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此段代码的功能是对输入数据的一半隐藏维度进行旋转操作。将原本的后半部分旋转到前面，将原本的前半部分旋转到后面。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotate_half(x):\n",
    "    # 将输入张量 x 沿 emb 维度一分为二\n",
    "    x1 = x[..., : x.shape[-1] // 2]\n",
    "    x2 = x[..., x.shape[-1] // 2 :]\n",
    "    # 将后半部分取负号，然后与前半部分拼接，对输入张量的隐藏维度进行旋转\n",
    "    return torch.cat((-x2, x1), dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此函数将旋转位置嵌入（Rotary Position Embedding）应用于查询和键张量。首先，函数获取键张量的数据类型，然后根据位置索引提取旋转嵌入的余弦和正弦部分，并在指定维度上进行扩展。为了提高计算的精度，在进行 embedding 计算时，从 bfloat16 数据类型转换为 float32 数据类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n",
    "    # 保存原始数据类型\n",
    "    orig_dtype = k.dtype  # torch.bfloat16\n",
    "    \n",
    "    # 根据 position_ids 选择 cos 和 sin，并在指定维度上扩展\n",
    "    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim] 便于和[bs, num_heads, q_len, head_dim] 维度的 q,k 进行矩阵乘法\n",
    "    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]\n",
    "    \n",
    "    # 将 q 和 k 转换为 float32 类型，以便进行精确的计算\n",
    "    q_fp32 = q.to(dtype=torch.float32, device=q.device)\n",
    "    k_fp32 = k.to(dtype=torch.float32, device=k.device)\n",
    "    \n",
    "    # 计算 q 和 k 的旋转位置嵌入\n",
    "    q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)\n",
    "    k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)\n",
    "    \n",
    "    # 将结果转换回原始数据类型并返回\n",
    "    return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)  # [bs, num_heads, q_len, head_dim]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在语言模型中，未来的 token 在当前时间步骤中是不可见的。因此，我们构造一个上三角矩阵来屏蔽未来的信息。在此矩阵中，对角线以上的部分（即未来的元素）被设置为极小的浮点数值（通常为负无穷大），这样做的目的是在自注意力机制的计算过程中，使这些部分被忽略或仅被赋予极小的权重，从而确保模型仅能“感知”到之前的元素。若存在缓存，则需要将过去的缓存纳入考虑范围。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_causal_mask(input_shape, dtype, device, past_length=0):\n",
    "    batch_size, query_length = input_shape\n",
    "    # 创建一个上三角矩阵，填充最小浮点值，表示未来的token不能看到\n",
    "    causal_mask = torch.triu(torch.full((query_length, query_length), torch.finfo(dtype).min, dtype=dtype, device=device), diagonal=1)\n",
    "    # 如果有过去的key-value长度，则在mask前面添加零矩阵\n",
    "    if past_length > 0:\n",
    "        causal_mask = torch.cat([torch.zeros(query_length, past_length, dtype=dtype, device=device), causal_mask], dim=-1)\n",
    "    # 扩展mask的维度以匹配批次大小，并返回\n",
    "    return causal_mask[None, None, :, :].expand(batch_size, 1, query_length, query_length + past_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 MiniCPM 模型中，原始的分词器（tokenizer）生成的掩码（mask）矩阵是一个二维矩阵，其中0表示填充（padding）位置，1表示真实令牌（token）位置。在注意力（attention）层中，我们需要将这个掩码矩阵扩展到四维，以便它能够与注意力矩阵进行逐元素相乘。这一步骤是为了确保模型在计算注意力权重时，只考虑真实令牌的位置，而忽略填充位置，从而提高模型处理不同长度输入序列的能力。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand_attention_mask(mask, dtype, target_length = None):\n",
    "    batch_size, source_length = mask.shape\n",
    "    target_length = target_length if target_length is not None else source_length\n",
    "\n",
    "    # 扩展mask的维度以匹配目标长度和批次大小\n",
    "    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_length, source_length).to(dtype)\n",
    "    # 反转mask，将1变为0，0变为1\n",
    "    inverted_mask = 1.0 - expanded_mask\n",
    "    # 将反转后的mask中为True的位置填充为最小浮点值\n",
    "    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "组合我们设计好的用于因果语言模型的 mask 和 padding mask，得到最终的 mask 矩阵。这个矩阵的作用是在自注意力机制中，屏蔽未来的信息，确保模型只能“感知”到之前的元素。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def prepare_4d_causal_attention_mask(\n",
    "    attention_mask: Optional[torch.Tensor],\n",
    "    query_length: int,\n",
    "    past_length: int,\n",
    "    dtype: torch.dtype,\n",
    "    device: Union[torch.device, \"str\"] = \"cpu\",\n",
    "):\n",
    "\n",
    "    # 如果attention_mask存在且是2维的\n",
    "    if attention_mask is not None and attention_mask.dim() == 2:\n",
    "        # 获取批次大小和查询长度\n",
    "        batch_size = attention_mask.shape[0]\n",
    "        query_length = query_length\n",
    "        # 更新input_shape和past_length\n",
    "        input_shape = (batch_size, query_length)\n",
    "        causal_mask = None\n",
    "        if query_length > 1:\n",
    "            # 创建4维的causal mask\n",
    "            causal_mask = create_causal_mask(input_shape, dtype, device, past_length)\n",
    "        # 扩展attention mask\n",
    "        expanded_mask = expand_attention_mask(attention_mask, dtype, query_length)\n",
    "        if causal_mask is not None:\n",
    "            # 将causal mask中对应expanded mask为True的位置填充为最小浮点值\n",
    "            expanded_attn_mask = causal_mask.masked_fill(expanded_mask.bool(), torch.finfo(dtype).min)\n",
    "        expanded_attn_mask = expanded_mask\n",
    "    return expanded_attn_mask\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MiniCPMAttention` 通过多头注意力机制高效处理长序列数据。它融合了动态头维度分配、旋转式位置编码（RoPE）、以及键值对缓存机制等多项技术，以提高模型的性能和灵活性。\n",
    "\n",
    "- **动态头维度分配**：通过将隐藏层的维度均匀分配给多个注意力头，实现了并行处理的优化，从而提高了计算效率。\n",
    "- **RoPE 位置编码**：引入了旋转式位置编码，以增强模型对序列位置信息的捕捉能力。这在处理长序列时尤其重要，因为它能够有效地保持位置信息的连续性和一致性。\n",
    "- **键值对缓存机制**：在自回归解码过程中，支持缓存先前计算的键值对，这一机制显著加速了连续解码任务的处理速度。\n",
    "\n",
    "相比如原始的 Attention，MiniCPMAttention 在计算 Embeddig 时采用 RoPE Embedding，这样可以更好地处理长序列。另外，MiniCPMAttention 支持键值对的缓存，这在自回归解码中非常有用，可以大大提高解码速度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiniCPMAttention(nn.Module):\n",
    "    def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.layer_idx = layer_idx\n",
    "        if layer_idx is None:\n",
    "            layer_idx.warn_once(\n",
    "                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n",
    "                \"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` \"\n",
    "                \"when creating this class.\"\n",
    "            )\n",
    "\n",
    "        self.attention_dropout = config.attention_dropout # 0.0\n",
    "        self.hidden_size = config.hidden_size # 2304\n",
    "        self.num_heads = config.num_attention_heads # 36\n",
    "        self.head_dim = self.hidden_size // self.num_heads # 64\n",
    "        self.num_key_value_heads = config.num_key_value_heads # 36\n",
    "        self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 1\n",
    "        self.max_position_embeddings = config.max_position_embeddings # 2048\n",
    "        self.rope_theta = config.rope_theta  # 10000.0\n",
    "        self.is_causal = True\n",
    "\n",
    "        if (self.head_dim * self.num_heads) != self.hidden_size:\n",
    "            raise ValueError(\n",
    "                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n",
    "                f\" and `num_heads`: {self.num_heads}).\"\n",
    "            )\n",
    "\n",
    "        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) # (2304, 36*64=2304)\n",
    "        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)\n",
    "        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)\n",
    "        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)\n",
    "        self._init_rope()\n",
    "\n",
    "    def _init_rope(self):\n",
    "        self.rotary_emb = MiniCPMRotaryEmbedding(\n",
    "            self.head_dim,\n",
    "            max_position_embeddings=self.max_position_embeddings,\n",
    "            base=self.rope_theta,\n",
    "        )\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states: torch.Tensor,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_value:  Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n",
    "        output_attentions: bool = False,\n",
    "        use_cache: bool = False,\n",
    "        **kwargs,\n",
    "    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n",
    "\n",
    "        bsz, q_len, _ = hidden_states.size()\n",
    "\n",
    "        # q,k,v 矩阵\n",
    "        query_states = self.q_proj(hidden_states)\n",
    "        key_states = self.k_proj(hidden_states)\n",
    "        value_states = self.v_proj(hidden_states)\n",
    "        \n",
    "        # 拆成 num_heads 个头 (bsz, num_heads, q_len, self.head_dim)\n",
    "        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
    "        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
    "\n",
    "        kv_seq_len = key_states.shape[-2]\n",
    "        if past_key_value is not None and len(past_key_value) > 0 and len(past_key_value[0]) > self.layer_idx and len(past_key_value[0][self.layer_idx].shape) > 1:\n",
    "            # 如果有 kv-cache 缓存，需要加上缓存的长度\n",
    "            kv_seq_len += past_key_value[0][self.layer_idx].shape[0] \n",
    "            \n",
    "        # 获取 RoPE Embedding 对应位置的 cos 和 sin 值 （ 这里传入的 value_states 不会参与计算，只是确保类型和设备）\n",
    "        cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)\n",
    "        \n",
    "        # 对 q 和 k 向量应用 RoPE 位置编码\n",
    "        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n",
    "        # 如果存在先前的 k-v 缓存\n",
    "        if past_key_value is not None:\n",
    "            # 若当前层缓存未初始化，则进行初始化\n",
    "            if len(past_key_value[0]) <= self.layer_idx:\n",
    "                # 为当前层新增 k-v 的缓存\n",
    "                past_key_value[0].append(key_states)\n",
    "                past_key_value[1].append(value_states)\n",
    "            else:\n",
    "                # 若当前层缓存已存在，通过在序列长度维度上进行拼接更新缓存\n",
    "                past_key_value[0][self.layer_idx] = torch.cat([past_key_value[0][self.layer_idx], key_states], dim=-2)\n",
    "                past_key_value[1][self.layer_idx] = torch.cat([past_key_value[1][self.layer_idx], value_states], dim=-2)\n",
    "\n",
    "            key_states, value_states = past_key_value[0][self.layer_idx], past_key_value[1][self.layer_idx]   \n",
    "            \n",
    "        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
    "        \n",
    "        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n",
    "            raise ValueError(\n",
    "                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n",
    "                f\" {attn_weights.size()}\"\n",
    "            )\n",
    "\n",
    "        if attention_mask is not None:\n",
    "            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n",
    "                raise ValueError(\n",
    "                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n",
    "                )\n",
    "            attn_weights = attn_weights + attention_mask\n",
    "\n",
    "        # 使用32位浮点数精度以提高计算精度\n",
    "        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n",
    "        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n",
    "        attn_output = torch.matmul(attn_weights, value_states)\n",
    "\n",
    "        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n",
    "            raise ValueError(\n",
    "                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n",
    "                f\" {attn_output.size()}\"\n",
    "            )\n",
    "            \n",
    "        attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "\n",
    "        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
    "\n",
    "        attn_output = self.o_proj(attn_output)\n",
    "\n",
    "        if not output_attentions:\n",
    "            attn_weights = None\n",
    "        \n",
    "        return attn_output, attn_weights, past_key_value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在模型中，注意力层（attention layer）占据了大部分的参数量，这主要归因于多个注意力头（attention heads）的参数。其中，查询（Q）、键（K）、值（V）三个矩阵的参数量相同。给定隐藏层大小（hidden_size）为 2304，并使用 64 个注意力头，每个头的维度设置为 36，那么这三个矩阵的总参数量计算为 `3*2304*36*64=15,925,248`。\n",
    "\n",
    "此外，还需要一个映射矩阵将这 64 个头的输出重新映射回输入的维度，该映射矩阵的参数量为 `2304*2304=5,308,416`。\n",
    "\n",
    "因此，注意力层的总参数量为 `15,925,248 + 5,308,416 = 21,233,664` 约 21M。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RMSNorm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`rms_layernorm` 是一种归一化层，它结合了 RMSProp 优化器和 Layer Normalization 的概念。可以对输入进行归一化处理，使得网络在训练过程中更加稳定。\n",
    "\n",
    "$$ y = W \\times \\left(\\frac{H}{\\sqrt{mean(H^2) + \\epsilon}}\\right) $$\n",
    "\n",
    "`rms_layernorm`层首先计算输入的平方的均值，然后用输入除以这个均值的平方根（加上一个很小的常数以防止除以零），从而确保输入的每个元素都在一个相对稳定的范围内。然后，这个层会乘以一个可学习的权重参数。\n",
    "\n",
    "这种归一化策略有助于减少训练过程中的内部协变量偏移，降低模型对初始化的敏感度，同时也能加速训练过程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class MiniCPMRMSNorm(nn.Module):\n",
    "    def __init__(self, hidden_size, eps=1e-6):\n",
    "        super().__init__()\n",
    "        # 初始化权重参数为1，形状由hidden_size决定\n",
    "        self.weight = nn.Parameter(torch.ones(hidden_size)) \n",
    "        # 设置方差的epsilon值，防止除以0\n",
    "        self.variance_epsilon = eps\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        # 保存输入的数据类型，以便后续恢复\n",
    "        old_dtype = hidden_states.dtype\n",
    "        # 计算方差，先转换数据类型以提高精度，然后计算平方的均值\n",
    "        variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)\n",
    "        # 标准化隐藏状态，使用rsqrt（方差+epsilon的倒数根）进行缩放，并恢复原数据类型\n",
    "        hidden_states = (hidden_states * torch.rsqrt(variance + self.variance_epsilon)).to(old_dtype)\n",
    "        # 应用权重参数，进行缩放\n",
    "        return hidden_states * self.weight\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SwiGLU 的 MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MiniCPM 的 MLP（多层感知器）结构采用 SwiGLU 激活层。该结构包含三个线性层：gate_proj、up_proj 和 down_proj，以及一个 SiLU 激活函数。将 gate_proj 层的结果通过 SiLU 激活函数转化，控制 up_proj 层的激活权重，对输入 x 进行特征提取和转换，然后通过 down_proj 层将转换后的特征映射回原始维度，从而实现一次前向传播。这种设计策略使得模型在保持输出维度不变的同时，能够有效地提取和转换输入特征。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SiLU（Sigmoid Linear Unit）激活函数是一种非线性函数，其公式为 $$ f(x) = x \\cdot \\sigma(x) $$当输入值为负时，该函数的输出接近于0；而当输入值为正时，输出则接近于输入值本身。这种特性使得 SiLU 函数具有无上界、有下界、平滑且非单调的特征。在深度学习模型的众多实践中，SiLU 函数已被证明在性能上超越了 ReLU 及其他激活函数。SiLU 函数不仅继承了 ReLU 激活函数的优点（例如，能够有效缓解梯度消失问题），同时也克服了 ReLU 函数的一些不足（例如，ReLU 函数在负数部分梯度为零，且非零中心）。此外，SiLU 函数是一种平滑函数，这意味着在其整个定义域内都存在导数，这对于优化过程是极其有利的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "  \n",
    "class MiniCPMMLP(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.hidden_size = config.hidden_size # 2304\n",
    "        self.intermediate_size = config.intermediate_size # 5760\n",
    "        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n",
    "        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n",
    "        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n",
    "        self.act_fn = nn.SiLU()\n",
    "\n",
    "    def forward(self, x): \n",
    "        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n",
    "        return down_proj"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MLP 层是模型参数量的另一个重要来源。在 MiniCPM 模型中，MLP 层的参数量主要来自于三个线性层（gate_proj、up_proj 和 down_proj）的参数。给定隐藏层大小（hidden_size）为 2304，up_proj 和 gate_proj 将均数据升维到 5760，down_proj 再降维到 2304，那么这三个线性层的参数量分别为 `2304*5760=13,276,160`，`2304*5760=13,276,160`，`5760*2304=13,276,160`。\n",
    "\n",
    "MLP 层的总参数量为 `13,276,160 + 13,276,160 + 13,276,160 = 39,828,480`，约 39M。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DecoderLayer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在构建 MiniCPM 模型的解码器层 `MiniCPMDecoderLayer` 时，我们将充分利用已经构建的关键组件：`MiniCPMAttention` 类负责执行注意力计算，`MiniCPMMLP` 类处理全连接层的运算，而 `MiniCPMRMSNorm` 类则负责执行层归一化操作，这包括对输入的隐藏状态进行归一化以及在注意力计算之后进行归一化处理。\n",
    "\n",
    "解码器层的处理流程遵循了解码器层设计的通用模式。首先，对输入的隐藏状态进行层归一化处理，接着通过自注意力机制对其进行加工处理。处理后的隐藏状态会与原始的隐藏状态进行残差连接，然后进行比例缩放。之后，对这个经过残差连接和比例缩放处理的隐藏状态再次进行层归一化处理，并通过全连接层进行加工处理。处理后的隐藏状态再次与原始的隐藏状态进行残差连接，并进行比例缩放。\n",
    "\n",
    "\n",
    "在深层神经网络中，随着层数的增加，残差连接的累积可能导致梯度爆炸或梯度消失的问题。通过引入缩放机制，可以确保每一层的输出保持在一个合理的范围内，从而提升训练过程的稳定性和模型的整体性能。通过缩放因子 `self.scale_depth / math.sqrt(self.num_hidden_layers)` 调整残差连接的贡献度，以确保每一层的输出既不会因层数增加而过大，也不会过小。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class MiniCPMDecoderLayer(nn.Module):\n",
    "    def __init__(self, config: MiniCPMConfig, layer_idx: int):\n",
    "        super().__init__()\n",
    "        self.hidden_size = config.hidden_size\n",
    "        self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)\n",
    "\n",
    "        self.mlp = MiniCPMMLP(config)\n",
    "        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
    "        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
    "\n",
    "        self.scale_depth = config.scale_depth\n",
    "        self.num_hidden_layers = config.num_hidden_layers\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states: torch.Tensor,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n",
    "        output_attentions: Optional[bool] = False,\n",
    "        use_cache: Optional[bool] = False,\n",
    "        **kwargs,\n",
    "    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n",
    "        \n",
    "        residual = hidden_states\n",
    "        # 对输入归一化\n",
    "        hidden_states = self.input_layernorm(hidden_states)\n",
    "        # Self Attention 计算\n",
    "        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n",
    "            hidden_states=hidden_states,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            past_key_value=past_key_value,\n",
    "            output_attentions=output_attentions,\n",
    "            use_cache=use_cache,\n",
    "            **kwargs,\n",
    "        )\n",
    "        # 应用残差连接并缩放\n",
    "        hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))\n",
    "\n",
    "        residual = hidden_states\n",
    "        # 对 attention 结果归一化\n",
    "        hidden_states = self.post_attention_layernorm(hidden_states)\n",
    "\n",
    "        hidden_states = self.mlp(hidden_states)\n",
    "        # 应用残差连接并缩放\n",
    "        hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))\n",
    "\n",
    "        outputs = (hidden_states,)\n",
    "\n",
    "        if output_attentions:\n",
    "            outputs += (self_attn_weights,)\n",
    "\n",
    "        if use_cache:\n",
    "            outputs += (present_key_value,)\n",
    "\n",
    "        return outputs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "每个解码层都由 attention 层和 MLP 层组成，所以一个解码器的参数量为 `21,233,664 + 39,828,480 = 61,062,144` 约 61M。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用一个  Model 类进行所有 MiniCPM 的基本配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "  \n",
    "class MiniCPMPreTrainedModel(nn.Module):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        self.config = args[0]\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "    def _init_weights(self, module):\n",
    "        std = self.config.initializer_range\n",
    "        if isinstance(module, nn.Linear):\n",
    "            module.weight.data.normal_(mean=0.0, std=std)\n",
    "            if module.bias is not None:\n",
    "                module.bias.data.zero_()\n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            module.weight.data.normal_(mean=0.0, std=std)\n",
    "            if module.padding_idx is not None:\n",
    "                module.weight.data[module.padding_idx].zero_()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MiniCPMModel 使整个模型的核心部分，它负责整个模型的前向计算过程。包括以下几个关键步骤：\n",
    "\n",
    "1. **参数校验**：确保`input_ids`和`inputs_embeds`不会同时指定。\n",
    "2. **位置ID处理**：若未提供`position_ids`，则自动创建一个序列。\n",
    "3. **词嵌入生成**：基于`input_ids`生成词嵌入，或直接采用`inputs_embeds`。\n",
    "4. **注意力掩码准备**：构造一个四维的因果注意力掩码。\n",
    "5. **隐藏状态初始化**：以词嵌入向量初始化隐藏状态。\n",
    "\n",
    "在完成隐藏状态的初始化后，模型通过若干解码器层对隐藏状态进行加工处理。在此过程中，根据需求，模型能够输出隐藏状态和注意力机制的详细信息。这包括对最终层隐藏状态的归一化处理，以及对所有隐藏状态和自注意力机制输出的汇总。此外，还涉及到批次大小和序列长度的计算、缓存机制的管理、位置索引的生成、词嵌入层的操作、解码器层的加工处理，以及最终输出层的归一化处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class MiniCPMModel(MiniCPMPreTrainedModel):\n",
    "\n",
    "    def __init__(self, config: MiniCPMConfig):\n",
    "        super().__init__(config)\n",
    "\n",
    "        self.padding_idx = config.pad_token_id\n",
    "        self.vocab_size = config.vocab_size\n",
    "\n",
    "        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n",
    "        self.layers = nn.ModuleList(\n",
    "            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n",
    "        )\n",
    "\n",
    "        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n",
    "\n",
    "        self.gradient_checkpointing = False\n",
    "        # self._init_weights()\n",
    "        \n",
    "    def _init_weights(self, module):\n",
    "        std = self.config.initializer_range\n",
    "        if isinstance(module, nn.Linear):\n",
    "            module.weight.data.normal_(mean=0.0, std=std)\n",
    "            if module.bias is not None:\n",
    "                module.bias.data.zero_()\n",
    "        elif isinstance(module, nn.Embedding):\n",
    "            module.weight.data.normal_(mean=0.0, std=std)\n",
    "            if module.padding_idx is not None:\n",
    "                module.weight.data[module.padding_idx].zero_()\n",
    "            \n",
    "    def get_input_embeddings(self):\n",
    "        return self.embed_tokens\n",
    "\n",
    "    def set_input_embeddings(self, value):\n",
    "        self.embed_tokens = value\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ) -> Union[Tuple, BaseModelOutputWithPast]:\n",
    "        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
    "        output_hidden_states = (\n",
    "            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
    "        )\n",
    "        use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        if input_ids is not None and inputs_embeds is not None:\n",
    "            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n",
    "        elif input_ids is not None:\n",
    "            batch_size, seq_length = input_ids.shape[:2]\n",
    "        elif inputs_embeds is not None:\n",
    "            batch_size, seq_length = inputs_embeds.shape[:2]\n",
    "        else:\n",
    "            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n",
    "\n",
    "        past_key_values_length = 0\n",
    "        \n",
    "        if use_cache:\n",
    "            if past_key_values is not None and len(past_key_values) > 0 and len(past_key_values[0]) > 0 and len(past_key_values[0][0].shape) > 2:\n",
    "                past_key_values_length = past_key_values[0][0].shape[-2]\n",
    "\n",
    "        if position_ids is None:\n",
    "            device = input_ids.device if input_ids is not None else inputs_embeds.device\n",
    "            position_ids = torch.arange(\n",
    "                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n",
    "            )\n",
    "            position_ids = position_ids.unsqueeze(0)\n",
    "\n",
    "        if inputs_embeds is None:\n",
    "            inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb\n",
    "\n",
    "        attention_mask = prepare_4d_causal_attention_mask(attention_mask, seq_length, past_key_values_length, inputs_embeds.dtype, inputs_embeds.device)\n",
    "        \n",
    "        # embed positions\n",
    "        hidden_states = inputs_embeds\n",
    "\n",
    "        # decoder layers\n",
    "        all_hidden_states = () if output_hidden_states else None\n",
    "        all_self_attns = () if output_attentions else None\n",
    "        next_decoder_cache = None\n",
    "\n",
    "        for decoder_layer in self.layers:\n",
    "            if output_hidden_states:\n",
    "                all_hidden_states += (hidden_states,)\n",
    "\n",
    "            layer_outputs = decoder_layer(\n",
    "                hidden_states,\n",
    "                attention_mask=attention_mask,\n",
    "                position_ids=position_ids,\n",
    "                past_key_value=past_key_values,\n",
    "                output_attentions=output_attentions,\n",
    "                use_cache=use_cache,\n",
    "            )\n",
    "\n",
    "            hidden_states = layer_outputs[0]\n",
    "\n",
    "            if use_cache:\n",
    "                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n",
    "\n",
    "            if output_attentions:\n",
    "                all_self_attns += (layer_outputs[1],)\n",
    "        # 对最终的结果归一化\n",
    "        hidden_states = self.norm(hidden_states)\n",
    "\n",
    "        # 添加最后一个解码器层的隐藏状态\n",
    "        if output_hidden_states:\n",
    "            all_hidden_states += (hidden_states,)\n",
    "\n",
    "        next_cache = None\n",
    "        if use_cache:\n",
    "            next_cache = next_decoder_cache\n",
    "        if not return_dict:\n",
    "            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)\n",
    "        return BaseModelOutputWithPast(\n",
    "            last_hidden_state=hidden_states,\n",
    "            past_key_values=next_cache,\n",
    "            hidden_states=all_hidden_states,\n",
    "            attentions=all_self_attns,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Embedding 占模型中非常大的一个参数量，这里的为 `122753 * 2304 = 282,822,912`，即约 282M 参数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先定义一个 `CausalLMOutputWithPast`类，主要用于因果语言模型（或自回归模型）的输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalLMOutputWithPast(OrderedDict):\n",
    "    loss: Optional[torch.FloatTensor] = None\n",
    "    logits: torch.FloatTensor = None\n",
    "    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None\n",
    "    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n",
    "    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们来看一下如何准备输入数据的步骤。\n",
    "\n",
    "1. **调整输入数据**：利用`adjust_input_ids`函数，根据提供的注意力掩码或之前计算出的键值对长度，调整`input_ids`的长度，以确保其符合模型所期望的长度，从而能够正确地应用注意力机制。\n",
    "\n",
    "2. **处理先前的键值对**：计算先前键值对的长度，并基于此调整`input_ids`和`attention_mask`。\n",
    "\n",
    "3. **生成位置ID**：对于Transformer模型而言，位置ID极为关键，它为模型提供了序列中各个元素的位置信息。如果没有直接提供位置ID但提供了注意力掩码，该函数将依据注意力掩码来生成位置ID。\n",
    "\n",
    "4. **更新模型输入**：根据是否提供了`inputs_embeds`以及是否利用了先前的键值对，该函数决定使用哪种类型的输入，并将位置ID、先前的键值对、是否使用缓存以及注意力掩码等信息综合到模型输入中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def prepare_inputs_for_generation(\n",
    "            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n",
    "        ):\n",
    "        # 调整输入以匹配注意力掩码或过去的键值长度\n",
    "        def adjust_input_ids(input_ids, attention_mask, past_length):\n",
    "            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n",
    "                return input_ids[:, -(attention_mask.shape[1] - past_length):]\n",
    "            elif past_length < input_ids.shape[1]:\n",
    "                return input_ids[:, past_length:]\n",
    "            return input_ids\n",
    "\n",
    "        # 根据 kv 缓存的长度调整输入\n",
    "        if past_key_values is not None and len(past_key_values) > 0 and len(past_key_values[0]) > 0 and len(past_key_values[0][0].shape) > 2:\n",
    "            cache_length = past_length = past_key_values[0][0].shape[2]\n",
    "            max_cache_length = None\n",
    "\n",
    "            input_ids = adjust_input_ids(input_ids, attention_mask, past_length)\n",
    "\n",
    "            if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:\n",
    "                attention_mask = attention_mask[:, -max_cache_length:]\n",
    "        \n",
    "        # 按照注意力掩码生成位置ID\n",
    "        position_ids = kwargs.get(\"position_ids\", None)\n",
    "        if attention_mask is not None and position_ids is None:\n",
    "            position_ids = attention_mask.long().cumsum(-1) - 1\n",
    "            position_ids.masked_fill_(attention_mask == 0, 1)\n",
    "            if past_key_values:\n",
    "                position_ids = position_ids[:, -input_ids.shape[1]:]\n",
    "       \n",
    "        # 更新模型输入\n",
    "        model_inputs = {\"inputs_embeds\": inputs_embeds} if inputs_embeds is not None and past_key_values is None else {\"input_ids\": input_ids}\n",
    "        \n",
    "        model_inputs.update(\n",
    "            {\n",
    "                \"position_ids\": position_ids,\n",
    "                \"past_key_values\": past_key_values,\n",
    "                \"use_cache\": kwargs.get(\"use_cache\"),\n",
    "                \"attention_mask\": attention_mask,\n",
    "            }\n",
    "        )\n",
    "        return model_inputs    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MiniCPM 采用了 tie-Embedding 的方式，即词嵌入层和输出层共享参数。这种方式可以减少模型的参数量，提高模型的训练效率。所以需要有获取和设置输入输出词嵌入层的方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MiniCPMForCausalLM`类通过继承`MiniCPMPreTrainedModel`继承了基础属性。在其构造函数`__init__`中，执行了以下几个关键步骤：\n",
    "\n",
    "1. **初始化父类**：通过`super().__init__(config)`调用父类构造函数，确保 config 被正确初始化。\n",
    "2. **构建模型核心**：实例化`MiniCPMModel`作为模型的核心组件。\n",
    "3. **定义线性层**：根据配置中的`vocab_size`确定词汇表的大小，并定义一个线性层`lm_head`。该线性层负责将隐藏层状态映射到词汇表上的得分（即logits），并明确指出不使用偏置项（`bias=False`），直接复用输入 Emb 层（读取权重时手动实现）。\n",
    "\n",
    "在`forward`方法中，执行了以下几个关键步骤：\n",
    "\n",
    "1. **确定输出内容**：依据配置来决定是否输出注意力权重和隐藏层状态。\n",
    "2. **执行前向传播**：调用`self.model`进行实际的前向传播计算，获取最后一层的隐藏层状态。\n",
    "3. **转换为logits**：通过`lm_head`将隐藏层状态转换为logits。\n",
    "4. **计算损失**：如果提供了标签，则根据这些标签计算交叉熵损失，这一步骤对模型的训练至关重要。\n",
    "5. **返回结果**：根据`return_dict`的设置，决定是返回一个包含所有输出的元组，还是返回一个命名的输出对象`CausalLMOutputWithPast`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "          \n",
    "class MiniCPMForCausalLM(MiniCPMPreTrainedModel):\n",
    "    _tied_weights_keys = [\"lm_head.weight\"]\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.model = MiniCPMModel(config)\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        # self.post_init()\n",
    "\n",
    "    def get_input_embeddings(self):\n",
    "        return self.model.embed_tokens\n",
    "\n",
    "    def set_input_embeddings(self, value):\n",
    "        self.model.embed_tokens = value\n",
    "\n",
    "    def get_output_embeddings(self):\n",
    "        return self.lm_head\n",
    "\n",
    "    def set_output_embeddings(self, new_embeddings):\n",
    "        self.lm_head = new_embeddings\n",
    "\n",
    "    def set_decoder(self, decoder):\n",
    "        self.model = decoder\n",
    "\n",
    "    def get_decoder(self):\n",
    "        return self.model\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        labels: Optional[torch.LongTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ) -> Union[Tuple, CausalLMOutputWithPast]:\n",
    "\n",
    "        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
    "        output_hidden_states = (\n",
    "            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
    "        )\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        # 调用模型\n",
    "        outputs = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "        \n",
    "        # 获取最后一层隐藏状态，并通过线性层（lm_head）转换为logits\n",
    "        hidden_states = outputs.last_hidden_state\n",
    "        logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))\n",
    "        logits = logits.float()\n",
    "        \n",
    "        loss = None\n",
    "        # 如果存在标签，则进行损失计算\n",
    "        if labels is not None:\n",
    "            # 对logits和labels进行错位，以便预测下一个token\n",
    "            shift_logits = logits[..., :-1, :].contiguous()\n",
    "            shift_labels = labels[..., 1:].contiguous()\n",
    "            # 为交叉熵损失计算准备，将tokens展平\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n",
    "            shift_labels = shift_labels.view(-1)\n",
    "            shift_labels = shift_labels.to(shift_logits.device)\n",
    "            # 计算交叉熵损失\n",
    "            loss = loss_fct(shift_logits, shift_labels)\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (logits,) + outputs[1:]\n",
    "            return (loss,) + output if loss is not None else output\n",
    "\n",
    "        return CausalLMOutputWithPast(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            past_key_values=outputs.past_key_values,\n",
    "            hidden_states=outputs.hidden_states,\n",
    "            attentions=outputs.attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义一个名为 `generate` 的方法，用于生成文本序列。该方法设计灵活，既能进行确定性的最大概率生成，也能通过随机采样产生更加多样化的输出。通过调整方法参数，用户可以在生成的质量与多样性之间做出权衡。\n",
    "\n",
    "1. **初始化缓存**：若启用缓存且未提供过去的键值对，则该方法会初始化一个空的键值对缓存。\n",
    "\n",
    "2. **准备输入**：计算批次大小，并初始化两个标志变量：`finished` 用于标记每个序列是否完成生成，`unfinished_sequences` 用于标记每个序列是否尚未完成。\n",
    "\n",
    "3. **获取 pad_token_id**：从 tokenizer 中获取 pad token 的 ID，该 ID 将用于后续填充生成的序列。\n",
    "\n",
    "4. **生成循环**：最多循环 `max_new_tokens` 次，每次循环生成一个新的 token。循环内部操作如下：\n",
    "   - 准备当前步骤的输入，并通过模型获取 logits。\n",
    "   - 若指定了 `top_k`，则将 logits 中非 top_k 的值设置为负无穷大，以便在采样时忽略它们。\n",
    "   - 根据 `do_sample` 参数决定是通过采样还是选择最大概率的 token 作为下一个 token。\n",
    "   - 更新输入序列，将新生成的 token 添加到输入序列中。\n",
    "   - 若提供了 `attention_mask`，则更新它以包括新的 token。\n",
    "   - 更新 `finished` 和 `unfinished_sequences` 标志，以标记哪些序列已完成或仍未完成。\n",
    "   - 检查是否所有序列都已完成，若是，则终止循环。\n",
    "\n",
    "5. **返回生成的序列**：最终，该方法返回包含原始及新生成 token 的输入序列。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    @torch.no_grad()\n",
    "    def generate(self, input_ids, max_new_tokens=1024, temperature=1.0, top_k=None, use_cache=False, past_key_values=None, tokenizer=None, do_sample=False, **model_kwargs):\n",
    "        if use_cache and past_key_values is None:\n",
    "            # 初始化 kv 缓存\n",
    "            past_key_values = ([], [])\n",
    "            model_kwargs[\"past_key_values\"] = past_key_values\n",
    "        batch_size = input_ids.size(0)\n",
    "        # 初始化完成标志和未完成序列标志\n",
    "        finished = torch.zeros(batch_size, dtype=torch.bool).to(input_ids.device)\n",
    "        unfinished_sequences = torch.ones(batch_size, dtype=torch.bool).to(input_ids.device)\n",
    "        # 获取 pad_token_id 用于填充\n",
    "        pad_token_id = tokenizer.pad_token_id  # 提前获取 pad_token_id\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            # 准备生成的输入\n",
    "            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n",
    "\n",
    "            logits = self(**model_inputs).logits[:, -1, :] / temperature  # Apply temperature\n",
    "            \n",
    "            if top_k is not None:\n",
    "                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]\n",
    "                logits[indices_to_remove] = -float('Inf')\n",
    "    \n",
    "            if do_sample:\n",
    "                probs = F.softmax(logits, dim=-1)\n",
    "                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n",
    "            else:\n",
    "                next_tokens = torch.argmax(logits, dim=-1)\n",
    "            \n",
    "            # 更新未完成序列的 next_tokens \n",
    "            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (~unfinished_sequences)        \n",
    "            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "            if \"attention_mask\" in model_kwargs:\n",
    "                # 更新 attention_mask\n",
    "                attention_mask = model_kwargs[\"attention_mask\"]\n",
    "                model_kwargs[\"attention_mask\"] = torch.cat(\n",
    "                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1\n",
    "                )\n",
    "            # 更新完成和未完成的序列标志\n",
    "            finished |= (next_tokens.squeeze(-1) == tokenizer.eos_token_id)\n",
    "            unfinished_sequences &= ~finished\n",
    "            \n",
    "            # 如果所有序列都完成，则停止生成\n",
    "            if finished.all():\n",
    "                break\n",
    "\n",
    "        return input_ids"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所以整个 MiniCPM 的非 Embedding 参数量为：\n",
    "\n",
    "整个模型有 40 个 decoder 层，每个 decoder 层由一个 21M 参数的 attention 层和一个 39M 参数的 MLP 层组成，共约 61M 参数。\n",
    "所以总参数量为 `61,062,144 * 40 = 2,442,485,760`，约 2.4B。\n",
    "\n",
    "Embedding 层后的参数为：`122753 * 2304 = 282,822,912`， 约 282M。\n",
    "\n",
    "考虑 Embedding 层后的总参数为： `2,442,485,760 + 282,822,912 = 2,725,308,672`，约 2.7B 参数。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aiLLM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
